In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import torch
import os

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import cv2
import matplotlib.pyplot as plt
C:\Users\lisas\anaconda3\envs\xai_model_explanation\lib\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
In [2]:
# Per-sample predictions on the test set; columns named '1'..'50' are used
# below as the model's predicted class at each of 50 training epochs.
predictions = pd.read_csv('predictions.csv')
In [3]:
# Ground-truth test labels, stored on disk as a torch tensor.
labels_path = os.path.join('data', 'y_test.pt')
labels = torch.load(labels_path)
# Attach the true class of each sample as a 'label' column.
predictions['label'] = pd.Series(labels)
In [4]:
# Fraction of the 50 epoch predictions that match the true label, per sample.
# Columns 0..49 hold the epoch predictions; column 50 holds the label
# (same positions the original per-row loop used). Vectorized with
# broadcasting instead of a Python loop with scalar .iloc access —
# identical values, dramatically faster on thousands of rows.
epoch_matrix = predictions.iloc[:, :50].to_numpy()
label_vector = predictions.iloc[:, 50].to_numpy()
correct_pred = (epoch_matrix == label_vector[:, None]).sum(axis=1) / 50
predictions['% correct'] = pd.Series(correct_pred)
In [5]:
# Remember each row's original position so it can be recovered after the
# frame is re-sorted below.
predictions['index'] = pd.Series(range(len(predictions)))
In [6]:
# Per-sample training-dynamics metrics, vectorized (the original looped over
# every row and epoch with scalar .iloc access — O(n_samples * n_epochs) in
# pure Python). Values are identical:
#   frequency         - fraction of the 49 consecutive epoch pairs where the
#                       prediction flipped
#   variability       - number of distinct classes predicted over the 50
#                       epochs, divided by 6 (the number of classes)
#   misclassification - fraction of the 50 epoch predictions that disagree
#                       with the true label
epoch_preds = predictions[[str(k) for k in range(1, 51)]].to_numpy()
true_labels = predictions['label'].to_numpy()

frequencies = (epoch_preds[:, 1:] != epoch_preds[:, :-1]).sum(axis=1) / 49
scores = (epoch_preds != true_labels[:, None]).sum(axis=1) / 50
# set() per row is unavoidable, but iterating a ndarray row is far cheaper
# than the scalar DataFrame indexing the original used.
variabilities = [len(set(row)) / 6 for row in epoch_preds]

predictions['frequency'] = pd.Series(frequencies)
predictions['variability'] = pd.Series(variabilities)
predictions['misclassification'] = pd.Series(scores)
In [7]:
# sort predictions

# Group rows by true label (0..5); within each label, sort by '% correct'
# descending; finally reverse the whole frame so the samples of the highest
# label / lowest '% correct' appear first (matches the Out[8] display).
# NOTE(review): `predictions.iloc[start:...] = sub_frame` relies on pandas
# writing the sorted values back positionally while keeping the original
# 0..n-1 index — this depends on .iloc setitem alignment semantics of the
# installed pandas version; verify before refactoring this cell.
start = 0
for label in range(6):
    sub_frame = predictions.loc[predictions['label'] == label].sort_values('% correct', ascending=False)
    predictions.iloc[start:start+len(sub_frame),:] = sub_frame
    start += len(sub_frame)
predictions = predictions[::-1]
In [8]:
# Sanity check of the first rows after sorting + reversal.
predictions.head()
Out[8]:
1 2 3 4 5 6 7 8 9 10 ... 47 48 49 50 label % correct index frequency variability misclassification
2992 0 0 0 0 0 0 0 0 0 0 ... 4 4 4 4 5 0.0 2831 0.326531 0.500000 1.0
2991 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 5 0.0 2773 0.122449 0.333333 1.0
2990 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 5 0.0 2493 0.000000 0.166667 1.0
2989 0 0 0 0 0 0 0 0 0 0 ... 1 1 1 1 5 0.0 2944 0.306122 0.333333 1.0
2988 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 5 0.0 2636 0.000000 0.166667 1.0

5 rows × 56 columns

In [9]:
# Test-set accuracy (%) at each of the 50 saved epochs
# (same arithmetic as the original accumulator loop).
test_acc = [
    np.sum(np.array(predictions[str(epoch)]) == np.array(predictions['label'])) / len(predictions) * 100
    for epoch in range(1, 51)
]
In [10]:
# Training-time metrics saved during the training run. Built with
# os.path.join instead of hard-coded Windows backslash paths so the
# notebook also runs on Linux/macOS (consistent with the y_test.pt load).
val_loss = torch.load(os.path.join('training', 'metrics', 'validation_loss.pt'))
train_loss = torch.load(os.path.join('training', 'metrics', 'train_loss.pt'))
train_acc = torch.load(os.path.join('training', 'metrics', 'train_accuracy.pt'))
In [11]:
def discrete_colorscale(bvals, colors):
    """
    Build a plotly-style discrete (stepwise) colorscale.

    bvals  - boundary values delimiting the intervals of interest
    colors - rgb/hex color codes, one per interval [bvals[k], bvals[k+1]]

    Returns a list of [normalized_value, color] pairs in which every color
    appears twice — once at each end of its interval — which is how plotly
    encodes a piecewise-constant colorscale.

    Raises ValueError when the number of boundaries is not len(colors)+1.
    """
    if len(bvals) != len(colors) + 1:
        raise ValueError('len(boundary values) should be equal to  len(colors)+1')
    bounds = sorted(bvals)
    span = bounds[-1] - bounds[0]
    # Normalize boundaries to [0, 1] as plotly expects.
    normalized = [(b - bounds[0]) / span for b in bounds]
    return [
        stop
        for k, color in enumerate(colors)
        for stop in ([normalized[k], color], [normalized[k + 1], color])
    ]
In [12]:
# One color bin per class: boundaries centered on the integer class ids 0..5.
bvals = [-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
colors = ['red', 'green', 'lightblue', 'yellow', 'darkblue', 'grey']

dcolorsc = discrete_colorscale(bvals, colors)

# Main heatmap: predicted class per sample (y) over epochs (x).
# NOTE(review): only the first 46 prediction columns are plotted
# (iloc[:, :46]) although 50 exist — presumably matching the 46 epochs used
# by the loss/accuracy traces below; confirm the cutoff is intentional.
heatmap = go.Heatmap(x = list(range(1,len(predictions)+1)), z=predictions.iloc[:,:46], colorscale = dcolorsc, showscale = False)
In [13]:
# Single-column heatmap of the true labels, sharing the discrete colorscale —
# serves as a per-row label strip beside the main heatmap.
heatmap_scale = go.Heatmap(x = list(range(1,len(predictions)+1)), z=np.array(predictions['label']).reshape((-1,1)), colorscale = dcolorsc, showscale = False)
In [14]:
# Horizontal bar traces for the three per-sample metrics, one bar per row,
# aligned with the heatmap rows.
freq = go.Bar({'x': predictions['frequency'], 'orientation': 'h'})
var = go.Bar({'x': predictions['variability'], 'orientation': 'h'})
miss = go.Bar({'x': predictions['misclassification'], 'orientation': 'h'})
In [15]:
# Loss curves over the first 46 epochs.
val_loss_trace = go.Scatter(x=list(range(1, 47)), y=val_loss[:46], line=dict(color='orange'))
train_loss_trace = go.Scatter(x=list(range(1, 47)), y=train_loss[:46], line=dict(color='purple'))
In [16]:
# Accuracy curves over the first 46 epochs.
test_acc_trace = go.Scatter(x=list(range(1, 47)), y=test_acc[:46], line=dict(color='darkorange'))
train_acc_trace = go.Scatter(x=list(range(1, 47)), y=train_acc[:46], line=dict(color='purple'))
In [17]:
# Confusion matrix: true label (rows) vs. prediction at epoch 46 (columns).
confusion_matrix = np.array(pd.crosstab(predictions['label'], predictions['46']))
# NOTE(review): dividing by np.sum(..., axis=0) normalizes each COLUMN
# (per predicted class, precision-style). If per-true-label rates were
# intended, axis=1 with broadcasting would be needed — confirm.
# NOTE(review): pd.crosstab only emits columns for classes actually
# predicted at epoch 46; if some class is never predicted, the hardcoded
# x tick labels below would misalign with the matrix columns.
cm = go.Heatmap({'z': confusion_matrix/np.sum(confusion_matrix, axis=0), 
                 'x':['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street'],
                 'y':['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street'],
                 'text': confusion_matrix/np.sum(confusion_matrix, axis=0),
                 'colorscale': 'greys'}, showscale = False)
In [18]:
# Assemble the dashboard: a 3x5 grid. Row 1 holds the label strip, the
# prediction-evolution heatmap and three metric bar charts; rows 2-3 hold
# the loss/accuracy curves and a confusion matrix spanning 2 rows x 3 cols.
fig = make_subplots(
    rows=3, cols=5,
    specs=[[{}, {}, {}, {}, {}],
           [None, {}, {"rowspan": 2, "colspan": 3}, None, None],
           [None, {}, None, None, None]],
    row_heights=[0.7, 0.15, 0.15], column_widths=[0.02, 0.7, 0.1, 0.1, 0.1],
    horizontal_spacing = 0.01, vertical_spacing = 0.06,
    subplot_titles=('Labels', 'Confusion Evolution',  'Frequency', 'Variability', 'Misclassification Rate', 'Loss', 'Confusion Matrix', 'Accuracy'),
    print_grid=False)

# Row 1: true-label strip and the per-epoch prediction heatmap.
fig.add_trace(heatmap_scale, row=1, col=1)
fig.add_trace(heatmap, row=1, col=2)

# Row 1, right side: per-sample metric bars aligned with the heatmap rows.
fig.add_trace(freq, row=1, col=3)
fig.add_trace(var, row=1, col=4)
fig.add_trace(miss, row=1, col=5)

# Row 2: validation vs. training loss curves share one subplot.
fig.add_trace(val_loss_trace, row=2, col=2)
fig.add_trace(train_loss_trace, row=2, col=2)

# Row 3: test vs. training accuracy curves share one subplot.
fig.add_trace(test_acc_trace, row=3, col=2)
fig.add_trace(train_acc_trace, row=3, col=2)

# The merged cell spanning rows 2-3: confusion matrix heatmap.
fig.add_trace(cm, row = 2, col = 3)

# Per-subplot axis tweaks: axisN maps to the Nth subplot created above.
fig.update_layout(plot_bgcolor='white', height=1500, width=1500,
                  xaxis={'visible': False}, yaxis={'visible': False}, 
                  xaxis2={'visible': False, 'side': 'top'}, yaxis2={'visible': False}, 
                  xaxis3={'visible': True, 'range': [0, 1]}, yaxis3={'visible': False}, 
                  xaxis4={'visible': True, 'range': [0, 1]}, yaxis4={'visible': False}, 
                  xaxis5={'visible': True, 'range': [0, 1]}, yaxis5={'visible': False}, 
                  xaxis6={'visible': False}, yaxis6={'visible': True, 'range': [0, 3], 'title': 'Loss'}, 
                  xaxis7={'visible': True}, yaxis7={'visible': True, 'side': 'right'}, 
                  xaxis8={'visible': True, 'title': 'epochs'}, yaxis8={'visible': True, 'range': [0, 100], 'title': 'Accuracy(%)'},
                  showlegend=False)

fig.update_coloraxes(showscale = False)
# Class-name annotations along the left label strip. The y positions are
# hand-tuned paper coordinates; white text is used where the strip color is
# dark so the names stay readable.
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.43,text="buildings", textangle = 270, showarrow = False, font=dict(size=16))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.52, text="forest", textangle = 270, showarrow = False, font=dict(size=16, color="white"))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.627,text="glacier", textangle = 270, showarrow = False, font=dict(size=16))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.766,text="mountain", textangle = 270, showarrow = False, font=dict(size=16))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.855,text="sea", textangle = 270, showarrow = False, font=dict(size=16, color="white"))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.967,text="street", textangle = 270, showarrow = False, font=dict(size=16, color="white"))

# Manually placed curve labels (the regular legend is hidden above); colors
# match the corresponding traces.
fig.add_annotation(xref='paper', yref='paper',x=0.54, y=0.3,text="validation loss", showarrow = False, font=dict(size=16, color="orange"))
fig.add_annotation(xref='paper', yref='paper',x=0.025, y=0.2,text="training loss", showarrow = False, font=dict(size=16, color="purple"))
fig.add_annotation(xref='paper', yref='paper',x=0.04, y=0.025,text="test accuracy", showarrow = False, font=dict(size=16, color="darkorange"))
fig.add_annotation(xref='paper', yref='paper',x=0.1, y=0.12,text="training accuracy", showarrow = False, font=dict(size=16, color="purple"))

fig.show()